# 外部から簡単にupscalerを呼ぶためのスクリプト
# 単体で動くようにモデル定義も含めている

import argparse
import glob
import os
import cv2
from diffusers import AutoencoderKL

from typing import Dict, List
import numpy as np

import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()

from torch import nn
from tqdm import tqdm
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

class ResidualBlock(nn.Module):
    """Two conv + BatchNorm layers with an identity skip connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. The skip path is a
    plain identity, so the conv path must preserve the shape of ``x``; with the
    defaults (``out_channels == in_channels``, ``kernel_size=3``, ``stride=1``,
    ``padding=1``) it does.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: channels produced by both convolutions; defaults to
            ``in_channels``.
        kernel_size, stride, padding: forwarded to both ``nn.Conv2d`` layers
            (which are created without bias).
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super().__init__()

        channels = in_channels if out_channels is None else out_channels
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        # conv -> BN -> ReLU
        self.conv1 = nn.Conv2d(in_channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU(inplace=True)

        # conv -> BN (the activation comes after the skip connection is added)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

        # NOTE: applied after the residual sum; applying it before the sum might work better
        self.relu2 = nn.ReLU(inplace=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv weights, BatchNorm weight=1/bias=0, N(0, 0.01) Linear weights."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        branch += x  # identity skip connection
        return self.relu2(branch)


class Upscaler(nn.Module):
    """Latent refiner used to upscale images in latent space.

    The network takes a 4-channel latent and processes it in three stages:
    - a conv stem;
    - 20 ``ResidualBlock``s, with an extra skip connection around every group
      of four;
    - a small conv head.

    The head's output is added back to the input latent, so the network
    predicts a correction to its input. ``conv_final`` is zero-initialised
    (weight and bias), so a freshly constructed model is the identity map.

    ``upscale`` runs the whole pipeline:
    1. LANCZOS-resize the low-resolution images to the target size.
    2. Encode them with the VAE.
    3. Refine the latents with this network.
    4. Return them multiplied by 0.18215.
    """

    # 20 residual blocks: stacking more blocks should widen the receptive field more than
    # adding channels would. An extra skip connection wraps every group of 4 consecutive blocks.
    _NUM_RESBLOCKS = 20
    _RESBLOCK_GROUP_SIZE = 4

    def __init__(self):
        super().__init__()

        # stem: the latent has 4 channels
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # resblock1 .. resblock20, registered in order. Attribute names become state_dict keys
        # (create_upscaler loads a state dict into this model), so they must not change.
        for index in range(1, self._NUM_RESBLOCKS + 1):
            self.add_module(f"resblock{index}", ResidualBlock(128))

        # last convs
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)  # defined but not applied in forward()

        # final conv: output 4 channels
        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv weights, BatchNorm weight=1/bias=0, zero conv_final weight."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # zero-initialise the final conv ("zero conv") so the untrained network returns its input
        nn.init.constant_(self.conv_final.weight, 0)

    def forward(self, x):
        """Return ``x`` plus the predicted correction (same shape as ``x``)."""
        inp = x

        x = self.relu1(self.bn1(self.conv1(x)))

        # Adding a skip connection after each group of resblocks is meant to improve accuracy
        # and training speed.
        resblocks = [getattr(self, f"resblock{index}") for index in range(1, self._NUM_RESBLOCKS + 1)]
        for start in range(0, len(resblocks), self._RESBLOCK_GROUP_SIZE):
            residual = x
            for block in resblocks[start : start + self._RESBLOCK_GROUP_SIZE]:
                x = block(x)
            x = x + residual

        x = self.relu2(self.bn2(self.conv2(x)))
        # No ReLU after bn3: the author's note says it is probably better without one here.
        x = self.bn3(self.conv3(x))

        x = self.conv_final(x)

        # the network estimates the difference between the input and the output
        return x + inp

    def support_latents(self) -> bool:
        """This upscaler works from images only; ``lowreso_latents`` is not used."""
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        """Upscale ``lowreso_images`` to ``width`` x ``height`` and return refined latents.

        Processing steps:
        1. Convert each image to RGB.
        2. LANCZOS-resize it to ``(width, height)``.
        3. Scale it to [-1, 1] in ``dtype``.
        4. Encode it with ``vae`` (sampling from the latent distribution) in
           batches of ``vae_batch_size``.
        5. Refine the latents with this model in batches of ``batch_size``.

        Args:
            vae: VAE used to encode the resized images.
            lowreso_images: low-resolution images; all are resized to the same size.
            lowreso_latents: ignored (see ``support_latents``).
            dtype: dtype of the pixel tensor fed to the VAE; presumably it must
                match the VAE's dtype — TODO confirm.
            width, height: target size in pixels.
            batch_size: batch size for this model.
            vae_batch_size: batch size for VAE encoding.

        Returns:
            The refined latents multiplied by 0.18215, on this model's device and dtype.

        Raises:
            AssertionError: if ``lowreso_images`` is None.
        """
        # assertion
        assert lowreso_images is not None, "Upscaler requires lowreso image"

        # Upsample with LANCZOS. Convert to RGB first: grayscale/palette images give a 2-D array
        # (the permute below would fail), and RGBA would give 4 channels where RGB is expected.
        upsampled_images = []
        for lowreso_image in lowreso_images:
            upsampled_image = np.array(lowreso_image.convert("RGB").resize((width, height), Image.LANCZOS))
            upsampled_images.append(upsampled_image)

        # Convert to an NCHW tensor. It stays on the CPU because it may be too large for the GPU;
        # only VAE batches are moved to the VAE's device.
        upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
        upsampled_images = torch.stack(upsampled_images, dim=0)
        upsampled_images = upsampled_images.to(dtype)

        # normalize to [-1, 1]
        upsampled_images = upsampled_images / 127.5 - 1.0

        # encode the upsampled images to latents, vae_batch_size at a time
        upsampled_latents = []
        for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
            batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
            with torch.no_grad():
                batch = vae.encode(batch).latent_dist.sample()
            upsampled_latents.append(batch)

        upsampled_latents = torch.cat(upsampled_latents, dim=0)

        # Refine the latents with this model, batch_size at a time. Move each batch to this
        # model's device/dtype so the VAE and the upscaler need not share a device or dtype.
        param = next(self.parameters())
        logger.info("Upscaling latents...")
        upscaled_latents = []
        with torch.no_grad():
            for i in range(0, upsampled_latents.shape[0], batch_size):
                batch = upsampled_latents[i : i + batch_size].to(device=param.device, dtype=param.dtype)
                upscaled_latents.append(self(batch))
        upscaled_latents = torch.cat(upscaled_latents, dim=0)

        return upscaled_latents * 0.18215


# external interface: returns a model
def create_upscaler(**kwargs):
    """Build an ``Upscaler`` and load its weights.

    Keyword Args:
        weights: path to the checkpoint. Files ending in ``.safetensors`` (any
            case) are read with ``safetensors``; anything else with ``torch.load``.
            The weights are loaded onto the CPU.

    Returns:
        The ``Upscaler`` with the state dict loaded.

    Raises:
        KeyError: if ``weights`` is not passed.
    """
    weights = kwargs["weights"]
    model = Upscaler()

    logger.info(f"Loading weights from {weights}...")
    # Compare the extension case-insensitively so e.g. "model.SAFETENSORS" is not sent to torch.load.
    if os.path.splitext(weights)[1].lower() == ".safetensors":
        from safetensors.torch import load_file

        sd = load_file(weights)
    else:
        # NOTE(security): torch.load unpickles the file and can execute arbitrary code;
        # only load checkpoints from trusted sources (prefer .safetensors).
        sd = torch.load(weights, map_location=torch.device("cpu"))
    model.load_state_dict(sd)
    return model


# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
    """Upscale every image matching ``args.image_pattern`` by 2x and save the results.

    Processing steps for each image:
    1. Crop it at the right/bottom edge to a multiple of 8 pixels.
    2. LANCZOS-resize it to twice its size.
    3. Encode it with the VAE.
    4. Refine it with the upscaler.
    5. Decode it and write it to ``args.output_dir`` as ``<name>_upscaled<ext>``.

    With ``args.debug`` the plain LANCZOS result is also written as
    ``<name>_lanczos4<ext>``.

    Images are grouped by size and each group is upscaled separately, because
    ``Upscaler.upscale`` resizes every image in one call to a single common
    size. If nothing matches the pattern, a warning is logged and the function
    returns.

    Args:
        args: parsed command-line arguments (``vae_path``, ``weights``,
            ``image_pattern``, ``output_dir``, ``batch_size``,
            ``vae_batch_size``, ``debug``).
    """
    DEVICE = get_preferred_device()
    us_dtype = torch.float16  # TODO: support fp32/bf16
    os.makedirs(args.output_dir, exist_ok=True)

    # validate inputs before loading any model
    assert args.vae_path is not None, "VAE path is required"
    assert args.image_pattern is not None, "Image pattern is required"
    image_paths = glob.glob(args.image_pattern)
    if not image_paths:
        logger.warning(f"No images found for pattern: {args.image_pattern}")
        return

    # load VAE with Diffusers
    logger.info(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    # prepare model
    logger.info("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # load images, cropped so both sides are divisible by 8
    images = [_crop_to_multiple_of_8(Image.open(image_path).convert("RGB")) for image_path in image_paths]

    # debug output: plain LANCZOS upscale for comparison
    if args.debug:
        for image, image_path in zip(images, image_paths):
            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)

            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_path))
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
            image_debug.save(dest_file_name)

    # Group image indices by size (order within a group is preserved) so that each image is
    # upscaled to 2x its own size rather than to the size of some other image.
    indices_by_size = {}
    for index, image in enumerate(images):
        indices_by_size.setdefault(image.size, []).append(index)

    # upscale
    logger.info("Upscaling...")
    for (width, height), indices in indices_by_size.items():
        group_images = [images[i] for i in indices]
        upscaled_latents = upscaler.upscale(
            vae, group_images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
        )
        upscaled_latents /= 0.18215  # undo the scaling applied by Upscaler.upscale

        logger.info("Decoding...")
        upscaled_images = _decode_latents_to_bgr(vae, upscaled_latents, args.vae_batch_size)

        # save images
        for i, image in zip(indices, upscaled_images):
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_paths[i]))
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
            cv2.imwrite(dest_file_name, image)


def _crop_to_multiple_of_8(image):
    """Crop ``image`` at its right/bottom edges so width and height are multiples of 8."""
    width = image.width - image.width % 8
    height = image.height - image.height % 8
    if (width, height) != (image.width, image.height):
        image = image.crop((0, 0, width, height))
    return image


def _decode_latents_to_bgr(vae, latents, vae_batch_size):
    """Decode ``latents`` with ``vae`` in batches; return uint8 NHWC images in BGR order for cv2."""
    decoded = []
    for i in tqdm(range(0, latents.shape[0], vae_batch_size)):
        with torch.no_grad():
            batch = vae.decode(latents[i : i + vae_batch_size]).sample
        decoded.append(batch.to("cpu"))

    # tensor to numpy: [-1, 1] -> [0, 255]
    images = torch.cat(decoded, dim=0).permute(0, 2, 3, 1).numpy()
    images = (images + 1.0) * 127.5
    images = images.clip(0, 255).astype(np.uint8)

    # RGB -> BGR for cv2.imwrite; the reversed view has a negative stride, so make it contiguous.
    return np.ascontiguousarray(images[..., ::-1])


if __name__ == "__main__":
    # Command-line entry point: parse options and run upscale_images.
    cli = argparse.ArgumentParser()
    # (flag, type, default, help) for every valued option, in --help order
    for flag, value_type, default_value, help_text in (
        ("--vae_path", str, None, "VAE path"),
        ("--weights", str, None, "Weights path"),
        ("--image_pattern", str, None, "Image pattern"),
        ("--output_dir", str, ".", "Output directory"),
        ("--batch_size", int, 4, "Batch size"),
        ("--vae_batch_size", int, 1, "VAE batch size"),
    ):
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    cli.add_argument("--debug", action="store_true", help="Debug mode")

    upscale_images(cli.parse_args())
